{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright 2021-2022 @ Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd\n",
    "\n",
    "This code is a part of Cybertron package.\n",
    "\n",
    "The Cybertron is open-source software based on the AI-framework:\n",
    "MindSpore (https://www.mindspore.cn/)\n",
    "\n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "you may not use this file except in compliance with the License.\n",
    "\n",
    "You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "Unless required by applicable law or agreed to in writing, software\n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "\n",
    "See the License for the specific language governing permissions and\n",
    "limitations under the License.\n",
    "\n",
    "Cybertron tutorial 06: Multi-task with multiple readouts (example 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[WARNING] ME(15775:140549731665728,MainProcess):2022-08-10-17:33:36.300.071 [mindspore/run_check/_check_version.py:137] Can not found cuda libs, please confirm that the correct cuda version has been installed, you can refer to the installation guidelines: https://www.mindspore.cn/install\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import time\n",
    "import numpy as np\n",
    "import mindspore as ms\n",
    "from mindspore import nn\n",
    "from mindspore import Tensor\n",
    "from mindspore import context\n",
    "from mindspore import dataset as ds\n",
    "from mindspore.train import Model\n",
    "from mindspore.train.callback import ModelCheckpoint, CheckpointConfig\n",
    "\n",
    "from cybertron import Cybertron\n",
    "from cybertron import MolCT\n",
    "from cybertron.train import MAE, MLoss\n",
    "from cybertron.train import WithLabelLossCell, WithLabelEvalCell\n",
    "from cybertron.train import TrainMonitor\n",
    "from cybertron.train import TransformerLR\n",
    "\n",
    "context.set_context(mode=context.GRAPH_MODE, device_target=\"GPU\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_name = sys.path[0] + '/dataset_qm9_normed_'\n",
    "train_file = data_name + 'trainset_1024.npz'\n",
    "valid_file = data_name + 'validset_128.npz'\n",
    "\n",
    "train_data = np.load(train_file)\n",
    "valid_data = np.load(valid_file)\n",
    "\n",
    "# diplole,polarizability,HOMO,LUMO,gap,R2,zpve,capacity\n",
    "idx = [0, 1, 2, 3, 4, 5, 6, 11]\n",
    "\n",
    "num_atom = int(train_data['num_atoms'])\n",
    "scale = Tensor(train_data['scale'][idx], ms.float32)\n",
    "shift = Tensor(train_data['shift'][idx], ms.float32)\n",
    "ref = Tensor(train_data['type_ref'][:, idx], ms.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "mod = MolCT(\n",
    "    cutoff=1,\n",
    "    n_interaction=3,\n",
    "    dim_feature=128,\n",
    "    n_heads=8,\n",
    "    activation='swish',\n",
    "    max_cycles=1,\n",
    "    length_unit='nm',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================\n",
      "Cybertron Engine, Ride-on!\n",
      "--------------------------------------------------------------------------------\n",
      "    Length unit: nm\n",
      "    Input unit scale: 1\n",
      "--------------------------------------------------------------------------------\n",
      "    Deep molecular model:  MolCT\n",
      "--------------------------------------------------------------------------------\n",
      "       Length unit: nm\n",
      "       Atom embedding size: 64\n",
      "       Cutoff distance: 1.0 nm\n",
      "       Radical basis function (RBF): LogGaussianBasis\n",
      "          Minimum distance: 0.04 nm\n",
      "          Maximum distance: 1.0 nm\n",
      "          Reference distance: 1.0 nm\n",
      "          Log Gaussian begin: -3.218876\n",
      "          Log Gaussian end: 0.006724119\n",
      "          Interval for log Gaussian: 0.0512\n",
      "          Sigma for log gaussian: 0.3\n",
      "          Number of basis functions: 64\n",
      "          Rescale the range of RBF to (-1,1).\n",
      "       Calculate distance: Yes\n",
      "       Calculate bond: No\n",
      "       Feature dimension: 128\n",
      "--------------------------------------------------------------------------------\n",
      "       Using 3 independent interaction layers:\n",
      "--------------------------------------------------------------------------------\n",
      "       0. Neural Interaction Unit\n",
      "          Feature dimension: 128\n",
      "          Activation function: Swish\n",
      "          Encoding distance: Yes\n",
      "          Encoding bond: No\n",
      "          Number of heads in multi-haed attention: 8\n",
      "          Use feed forward network: No\n",
      "--------------------------------------------------------------------------------\n",
      "       1. Neural Interaction Unit\n",
      "          Feature dimension: 128\n",
      "          Activation function: Swish\n",
      "          Encoding distance: Yes\n",
      "          Encoding bond: No\n",
      "          Number of heads in multi-haed attention: 8\n",
      "          Use feed forward network: No\n",
      "--------------------------------------------------------------------------------\n",
      "       2. Neural Interaction Unit\n",
      "          Feature dimension: 128\n",
      "          Activation function: Swish\n",
      "          Encoding distance: Yes\n",
      "          Encoding bond: No\n",
      "          Number of heads in multi-haed attention: 8\n",
      "          Use feed forward network: No\n",
      "--------------------------------------------------------------------------------\n",
      "    With 6 readout networks: \n",
      "--------------------------------------------------------------------------------\n",
      "    0. GraphReadout\n",
      "       Activation function: Swish\n",
      "       Decoder: HalveDecoder\n",
      "       Aggregator: TensorMean\n",
      "       Representation dimension: 128\n",
      "       Readout dimension: 1\n",
      "       Scale: 1.0\n",
      "       Shift: 0.0\n",
      "       Scaleshift mode: Graph\n",
      "       Reference value for atom types: None\n",
      "       Output unit: None\n",
      "       Reduce axis: -2\n",
      "--------------------------------------------------------------------------------\n",
      "    1. GraphReadout\n",
      "       Activation function: Swish\n",
      "       Decoder: HalveDecoder\n",
      "       Aggregator: TensorMean\n",
      "       Representation dimension: 128\n",
      "       Readout dimension: 1\n",
      "       Scale: 1.0\n",
      "       Shift: 0.0\n",
      "       Scaleshift mode: Graph\n",
      "       Reference value for atom types: None\n",
      "       Output unit: None\n",
      "       Reduce axis: -2\n",
      "--------------------------------------------------------------------------------\n",
      "    2. GraphReadout\n",
      "       Activation function: Swish\n",
      "       Decoder: HalveDecoder\n",
      "       Aggregator: TensorMean\n",
      "       Representation dimension: 128\n",
      "       Readout dimension: 3\n",
      "       Scale: 1.0\n",
      "       Shift: 0.0\n",
      "       Scaleshift mode: Graph\n",
      "       Reference value for atom types: None\n",
      "       Output unit: None\n",
      "       Reduce axis: -2\n",
      "--------------------------------------------------------------------------------\n",
      "    3. GraphReadout\n",
      "       Activation function: Swish\n",
      "       Decoder: HalveDecoder\n",
      "       Aggregator: TensorMean\n",
      "       Representation dimension: 128\n",
      "       Readout dimension: 1\n",
      "       Scale: 1.0\n",
      "       Shift: 0.0\n",
      "       Scaleshift mode: Graph\n",
      "       Reference value for atom types: None\n",
      "       Output unit: None\n",
      "       Reduce axis: -2\n",
      "--------------------------------------------------------------------------------\n",
      "    4. GraphReadout\n",
      "       Activation function: Swish\n",
      "       Decoder: HalveDecoder\n",
      "       Aggregator: TensorMean\n",
      "       Representation dimension: 128\n",
      "       Readout dimension: 1\n",
      "       Scale: 1.0\n",
      "       Shift: 0.0\n",
      "       Scaleshift mode: Graph\n",
      "       Reference value for atom types: None\n",
      "       Output unit: None\n",
      "       Reduce axis: -2\n",
      "--------------------------------------------------------------------------------\n",
      "    5. GraphReadout\n",
      "       Activation function: Swish\n",
      "       Decoder: HalveDecoder\n",
      "       Aggregator: TensorMean\n",
      "       Representation dimension: 128\n",
      "       Readout dimension: 1\n",
      "       Scale: 1.0\n",
      "       Shift: 0.0\n",
      "       Scaleshift mode: Graph\n",
      "       Reference value for atom types: None\n",
      "       Output unit: None\n",
      "       Reduce axis: -2\n",
      "--------------------------------------------------------------------------------\n",
      "    Output dimension: [1 1 3 1 1 1]\n",
      "    Total output dimension: 8\n",
      "    Output unit for Cybertron: None\n",
      "    Output unit scale: [1. 1. 1. 1. 1. 1. 1. 1.]\n",
      "================================================================================\n"
     ]
    }
   ],
   "source": [
    "net = Cybertron(mod, readout='graph', dim_output=[1, 1, 3, 1, 1, 1],\n",
    "                num_atoms=num_atom, length_unit='nm')\n",
    "net.print_info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 model.atom_embedding.embedding_table (64, 128)\n",
      "1 model.dis_filter.linear.weight (128, 64)\n",
      "2 model.dis_filter.linear.bias (128,)\n",
      "3 model.dis_filter.residual.nonlinear.mlp.0.weight (128, 128)\n",
      "4 model.dis_filter.residual.nonlinear.mlp.0.bias (128,)\n",
      "5 model.dis_filter.residual.nonlinear.mlp.1.weight (128, 128)\n",
      "6 model.dis_filter.residual.nonlinear.mlp.1.bias (128,)\n",
      "7 model.interactions.0.positional_embedding.norm.gamma (128,)\n",
      "8 model.interactions.0.positional_embedding.norm.beta (128,)\n",
      "9 model.interactions.0.positional_embedding.x2q.weight (128, 128)\n",
      "10 model.interactions.0.positional_embedding.x2k.weight (128, 128)\n",
      "11 model.interactions.0.positional_embedding.x2v.weight (128, 128)\n",
      "12 model.interactions.0.multi_head_attention.output.weight (128, 128)\n",
      "13 model.interactions.1.positional_embedding.norm.gamma (128,)\n",
      "14 model.interactions.1.positional_embedding.norm.beta (128,)\n",
      "15 model.interactions.1.positional_embedding.x2q.weight (128, 128)\n",
      "16 model.interactions.1.positional_embedding.x2k.weight (128, 128)\n",
      "17 model.interactions.1.positional_embedding.x2v.weight (128, 128)\n",
      "18 model.interactions.1.multi_head_attention.output.weight (128, 128)\n",
      "19 model.interactions.2.positional_embedding.norm.gamma (128,)\n",
      "20 model.interactions.2.positional_embedding.norm.beta (128,)\n",
      "21 model.interactions.2.positional_embedding.x2q.weight (128, 128)\n",
      "22 model.interactions.2.positional_embedding.x2k.weight (128, 128)\n",
      "23 model.interactions.2.positional_embedding.x2v.weight (128, 128)\n",
      "24 model.interactions.2.multi_head_attention.output.weight (128, 128)\n",
      "25 readout.0.decoder.output.mlp.0.weight (64, 128)\n",
      "26 readout.0.decoder.output.mlp.0.bias (64,)\n",
      "27 readout.0.decoder.output.mlp.1.weight (1, 64)\n",
      "28 readout.0.decoder.output.mlp.1.bias (1,)\n",
      "29 readout.1.decoder.output.mlp.0.weight (64, 128)\n",
      "30 readout.1.decoder.output.mlp.0.bias (64,)\n",
      "31 readout.1.decoder.output.mlp.1.weight (1, 64)\n",
      "32 readout.1.decoder.output.mlp.1.bias (1,)\n",
      "33 readout.2.decoder.output.mlp.0.weight (64, 128)\n",
      "34 readout.2.decoder.output.mlp.0.bias (64,)\n",
      "35 readout.2.decoder.output.mlp.1.weight (3, 64)\n",
      "36 readout.2.decoder.output.mlp.1.bias (3,)\n",
      "37 readout.3.decoder.output.mlp.0.weight (64, 128)\n",
      "38 readout.3.decoder.output.mlp.0.bias (64,)\n",
      "39 readout.3.decoder.output.mlp.1.weight (1, 64)\n",
      "40 readout.3.decoder.output.mlp.1.bias (1,)\n",
      "41 readout.4.decoder.output.mlp.0.weight (64, 128)\n",
      "42 readout.4.decoder.output.mlp.0.bias (64,)\n",
      "43 readout.4.decoder.output.mlp.1.weight (1, 64)\n",
      "44 readout.4.decoder.output.mlp.1.bias (1,)\n",
      "45 readout.5.decoder.output.mlp.0.weight (64, 128)\n",
      "46 readout.5.decoder.output.mlp.0.bias (64,)\n",
      "47 readout.5.decoder.output.mlp.1.weight (1, 64)\n",
      "48 readout.5.decoder.output.mlp.1.bias (1,)\n",
      "Total parameters:  296968\n"
     ]
    }
   ],
   "source": [
    "tot_params = 0\n",
    "for i, param in enumerate(net.get_parameters()):\n",
    "    tot_params += param.size\n",
    "    print(i, param.name, param.shape)\n",
    "print('Total parameters: ', tot_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_epoch = 8\n",
    "repeat_time = 1\n",
    "batch_size = 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_train = ds.NumpySlicesDataset(\n",
    "    {'R': train_data['R'], 'Z': train_data['Z'], 'E': train_data['E'][:, idx]}, shuffle=True)\n",
    "ds_train = ds_train.batch(batch_size, drop_remainder=True)\n",
    "ds_train = ds_train.repeat(repeat_time)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_valid = ds.NumpySlicesDataset(\n",
    "    {'R': valid_data['R'], 'Z': valid_data['Z'], 'E': valid_data['E'][:, idx]}, shuffle=False)\n",
    "ds_valid = ds_valid.batch(128)\n",
    "ds_valid = ds_valid.repeat(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WithLabelLossCell with input type: RZE\n",
      "WithLabelEvalCell with input type: RZE\n",
      "   with scaleshift for training and evaluate dataset:\n",
      "   Output.            Scale           Shift        Mode\n",
      "   0:        1.503183e+00    2.672827e+00       graph\n",
      "   1:        8.173762e+00    7.528103e+01       graph\n",
      "   2:        5.767056e+01   -6.306687e+02       graph\n",
      "   3:        1.229870e+02    3.108809e+01       graph\n",
      "   4:        1.238940e+02    6.617564e+02       graph\n",
      "   5:        2.804632e+02    1.189402e+03       graph\n",
      "   6:        8.699451e+01    3.914438e+02       graph\n",
      "   7:        6.082039e+00   -2.213512e+01       graph\n",
      "   with reference value for atom types:\n",
      "   Type     Label0    Label1    Label2    Label3    Label4    Label5    Label6    Label7\n",
      "   0:        0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00\n",
      "   1:        0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  2.98e+00\n",
      "   2:        0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00\n",
      "   3:        0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00\n",
      "   4:        0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00\n",
      "   5:        0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00\n",
      "   6:        0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  2.98e+00\n",
      "   7:        0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  2.98e+00\n",
      "   8:        0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  2.98e+00\n",
      "   9:        0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  2.98e+00\n"
     ]
    }
   ],
   "source": [
    "loss_network = WithLabelLossCell('RZE', net, nn.MAELoss())\n",
    "eval_network = WithLabelEvalCell('RZE', net, nn.MAELoss(), scale=scale, shift=shift, type_ref=ref)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = TransformerLR(learning_rate=1., warmup_steps=4000, dimension=128)\n",
    "optim = nn.Adam(params=net.trainable_params(), learning_rate=lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_mae = 'EvalMAE'\n",
    "atom_mae = 'AtomMAE'\n",
    "eval_loss = 'Evalloss'\n",
    "model = Model(loss_network, optimizer=optim, eval_network=eval_network,\n",
    "              metrics={eval_mae: MAE([1, 2], reduce_all_dims=False),\n",
    "                       atom_mae: MAE([1, 2, 3], reduce_all_dims=False, averaged_by_atoms=True),\n",
    "                       eval_loss: MLoss(0)},)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "outdir = 'Tutorial_C06'\n",
    "outname = outdir + '_' + net.model_name\n",
    "record_cb = TrainMonitor(model, outname, per_step=32, avg_steps=32,\n",
    "                         directory=outdir, eval_dataset=ds_valid, best_ckpt_metrics=eval_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "config_ck = CheckpointConfig(save_checkpoint_steps=32, keep_checkpoint_max=64, append_info=[net.hyper_param])\n",
    "ckpoint_cb = ModelCheckpoint(prefix=outname, directory=outdir, config=config_ck)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[WARNING] ME(15775:140549731665728,MainProcess):2022-08-10-17:33:54.865.42 [mindspore/train/model.py:1097] For TrainMonitor callback, {'begin', 'epoch_end', 'step_end'} methods may not be supported in later version, Use methods prefixed with 'on_train' or 'on_eval' instead when using customized callbacks.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Start training ...\n",
      "Epoch: 1, Step: 32, Learning_rate: 1.0830951e-05, Last_Loss: 0.8817136, Avg_loss: 0.9816328808665276, EvalMAE: [  1.354305    7.0573864  73.874756  101.67707    95.05714   201.4472     63.426956    4.721528 ], AtomMAE: [ 0.07720656  0.42350703  4.1470666   6.198917    5.4334555  10.866326    3.9789188   0.2760198 ], Evalloss: 0.8578746914863586\n",
      "Epoch: 2, Step: 64, Learning_rate: 2.2011289e-05, Last_Loss: 0.7452917, Avg_loss: 0.8000287134200335, EvalMAE: [  1.2932307   6.6266966  45.27566    87.648575   94.00476   191.88496    53.885666    4.3978643], AtomMAE: [ 0.07598674  0.39586204  2.769402    5.1205387   5.3762245  10.541361    3.18947     0.26172012], Evalloss: 0.744278073310852\n",
      "Epoch: 3, Step: 96, Learning_rate: 3.3191627e-05, Last_Loss: 0.6801481, Avg_loss: 0.7273177430033684, EvalMAE: [  1.2349188   5.7119527  42.25477    82.20943    91.34517   187.2038     44.88445     3.776333 ], AtomMAE: [ 0.07238529  0.33933985  2.5631635   4.847484    5.2437115  10.155171    2.7004845   0.2220696 ], Evalloss: 0.6828868389129639\n",
      "Epoch: 4, Step: 128, Learning_rate: 4.437197e-05, Last_Loss: 0.5953164, Avg_loss: 0.6627427469938993, EvalMAE: [  1.1655052   4.6168737  42.27237    77.76101    89.0306    178.15413    29.291348    2.9880452], AtomMAE: [0.06881031 0.27175498 2.5554192  4.5761414  5.1459417  9.668484   1.680477   0.16993114], Evalloss: 0.610909640789032\n",
      "Epoch: 5, Step: 160, Learning_rate: 5.5552304e-05, Last_Loss: 0.6420445, Avg_loss: 0.5844806600362062, EvalMAE: [  1.1080328   3.3406644  38.27096    72.17727    85.37679   166.633      18.394579    2.1148696], AtomMAE: [0.06533822 0.19855152 2.2799964  4.147756   4.9857426  9.085913   1.0709132  0.11902528], Evalloss: 0.5298410654067993\n",
      "Epoch: 6, Step: 192, Learning_rate: 6.6732646e-05, Last_Loss: 0.44723782, Avg_loss: 0.5265091108158231, EvalMAE: [  1.0813898   2.8325803  38.771576   70.069275   81.81321   147.10164    16.58376     1.9922956], AtomMAE: [0.06438599 0.17151271 2.3213203  4.068245   4.818325   8.092873   0.9700901  0.11263489], Evalloss: 0.5013766288757324\n",
      "Epoch: 7, Step: 224, Learning_rate: 7.791298e-05, Last_Loss: 0.48909324, Avg_loss: 0.4962798496708274, EvalMAE: [  1.034331    2.411187   35.77923    71.63514    78.6023    131.75517    16.253054    2.2487566], AtomMAE: [0.06156413 0.1454563  2.1290495  4.1329026  4.634448   7.3314805  0.92166775 0.12514809], Evalloss: 0.4808410108089447\n",
      "Epoch: 8, Step: 256, Learning_rate: 8.9093315e-05, Last_Loss: 0.4865139, Avg_loss: 0.47159935161471367, EvalMAE: [  1.0146574   2.337717   34.23776    71.79811    80.58784   113.800476   15.992944    1.7557149], AtomMAE: [0.0603363  0.13995764 2.0584826  4.1177306  4.7361355  6.2936735  0.92921084 0.09859011], Evalloss: 0.4584001302719116\n",
      "Training Fininshed!\n",
      "Training Time: 00:00:39\n"
     ]
    }
   ],
   "source": [
    "np.set_printoptions(linewidth=200)\n",
    "\n",
    "print(\"Start training ...\")\n",
    "beg_time = time.time()\n",
    "model.train(n_epoch, ds_train, callbacks=[\n",
    "    record_cb, ckpoint_cb], dataset_sink_mode=False)\n",
    "end_time = time.time()\n",
    "used_time = end_time - beg_time\n",
    "m, s = divmod(used_time, 60)\n",
    "h, m = divmod(m, 60)\n",
    "print(\"Training Fininshed!\")\n",
    "print(\"Training Time: %02d:%02d:%02d\" % (h, m, s))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.5 ('mindsponge')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "2496ecc683137a232cae2452fbbdd53dab340598b6e499c8995be760f3a431b4"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
